iT邦幫忙

2024 iThome 鐵人賽

DAY 29
0
AI/ ML & Data

基於人工智慧與深度學習對斑馬魚做行為分析系列 第 29

day 29 基於人工智慧與深度學習對斑馬魚做行為分析

  • 分享至 

  • xImage
  •  

今天是第29天我們可以寫一個基於人工智慧與深度學習對斑馬魚做行為分析的系統,以下是我寫得最有效率的程式碼

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import cv2
import numpy as np
from yolov8 import YOLOv8
from sort import Sort

# 自定義 CNN 特徵提取器
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(256 * 16 * 16, 1024)
        self.fc2 = nn.Linear(1024, 512)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 256 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

# 多層 LSTM 模型
class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super(MultiLayerLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        h0 = torch.zeros(num_layers, x.size(0), hidden_size).to(device)
        c0 = torch.zeros(num_layers, x.size(0), hidden_size).to(device)
        
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

# YOLOv8 模型初始化
yolo_model = YOLOv8("yolov8_weights.pth")

# SORT 跟踪器初始化
sort_tracker = Sort()

# 初始化 CNN 特徵提取器
cnn_feature_extractor = CNNFeatureExtractor()

# LSTM 模型參數
input_size = 512  # 來自 CNN 提取的特徵向量
hidden_size = 1024
output_size = 5  # 行為分類數量
num_layers = 3

# 初始化 LSTM 模型
lstm_model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
lstm_model.load_state_dict(torch.load("lstm_model.h5"))
lstm_model.eval()

# 損失函數與優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(cnn_feature_extractor.parameters()) + list(lstm_model.parameters()), lr=0.0001)

# 使用 GPU 設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_feature_extractor.to(device)
lstm_model.to(device)

# 視頻處理
cap = cv2.VideoCapture("zebrafish_video.mp4")

def extract_cnn_features(bbox, frame):
    """
    從給定的邊框和圖像框架中提取 CNN 特徵
    :param bbox: 邊框 (x1, y1, x2, y2)
    :param frame: 當前的圖像幀
    :return: 提取的 CNN 特徵數據
    """
    x1, y1, x2, y2 = bbox
    roi = frame[y1:y2, x1:x2]
    roi_resized = cv2.resize(roi, (64, 64))
    
    # 數據增強: 隨機翻轉、旋轉
    if np.random.rand() > 0.5:
        roi_resized = cv2.flip(roi_resized, 1)
    angle = np.random.uniform(-10, 10)
    M = cv2.getRotationMatrix2D((32, 32), angle, 1.0)
    roi_resized = cv2.warpAffine(roi_resized, M, (64, 64))
    
    roi_tensor = torch.tensor(roi_resized, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    
    with torch.no_grad():
        cnn_features = cnn_feature_extractor(roi_tensor)
    
    return cnn_features

def visualize_results(frame, bbox, behavior_class):
    """
    將結果可視化,將行為分類結果疊加在視頻幀上。
    :param frame: 當前的圖像幀
    :param bbox: 邊框 (x1, y1, x2, y2)
    :param behavior_class: 行為分類的標籤
    """
    x1, y1, x2, y2 = bbox
    color = (0, 255, 0)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    
    label = f"Behavior: {behavior_class}"
    cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # 斑馬魚檢測
    detections = yolo_model.detect(frame)
    
    # 斑馬魚跟蹤
    tracks = sort_tracker.update(detections)
    
    sequence_features = []
    
    for track in tracks:
        fish_id, bbox, _ = track
        cnn_features = extract_cnn_features(bbox, frame)
        
        sequence_features.append(cnn_features.squeeze(0).cpu().numpy())
    
    if len(sequence_features) > 0:
        sequence_features_tensor = torch.tensor(sequence_features, dtype=torch.float32).unsqueeze(0).to(device)
        
        # LSTM 模型預測
        with torch.no_grad():
            behavior_output = lstm_model(sequence_features_tensor)
        
        # 獲取行為分類結果
        _, behavior_class = torch.max(behavior_output, 1)
        behavior_class = behavior_class.item()
        
        # 可視化結果
        for i, track in enumerate(tracks):
            visualize_results(frame, track[1], behavior_class)
    
    cv2.imshow("Advanced Zebrafish Behavior Analysis", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

1. 自定義 CNN 特徵提取器

class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(256 * 16 * 16, 1024)
        self.fc2 = nn.Linear(1024, 512)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 256 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

解釋:

  • 這個類別是一個自定義的卷積神經網絡(CNN)特徵提取器。
  • 它包含三個卷積層,每個卷積層後接一個ReLU激活函數和池化層。池化層用來減少特徵圖的尺寸。
  • 最後,通過兩個全連接層將提取的特徵進一步降維到512維特徵向量。
  • forward 函數是定義前向傳播過程的地方,數據流經卷積層、池化層,最終經過全連接層。

2. 多層 LSTM 模型

class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=3):
        super(MultiLayerLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        h0 = torch.zeros(num_layers, x.size(0), hidden_size).to(device)
        c0 = torch.zeros(num_layers, x.size(0), hidden_size).to(device)
        
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

解釋:

  • 這個類定義了一個多層的LSTM模型,用於處理時序數據。LSTM(長短期記憶網絡)能夠捕捉長期依賴性,非常適合處理序列數據,如斑馬魚的行為序列。
  • input_size 是每個時間步輸入的特徵數量(即來自CNN的特徵向量大小)。
  • hidden_size 是LSTM單元的隱藏層大小,output_size 是最終行為分類的類別數。
  • forward 函數執行前向傳播,LSTM的初始隱藏狀態和細胞狀態都初始化為零向量,然後輸入數據進行LSTM計算,最後通過全連接層輸出分類結果。

3. 視頻處理與行為分析主邏輯

# YOLOv8 模型初始化
yolo_model = YOLOv8("yolov8_weights.pth")

# SORT 跟踪器初始化
sort_tracker = Sort()

# 初始化 CNN 特徵提取器
cnn_feature_extractor = CNNFeatureExtractor()

# LSTM 模型參數
input_size = 512  # 來自 CNN 提取的特徵向量
hidden_size = 1024
output_size = 5  # 行為分類數量
num_layers = 3

# 初始化 LSTM 模型
lstm_model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
lstm_model.load_state_dict(torch.load("lstm_model.h5"))
lstm_model.eval()

# 損失函數與優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(cnn_feature_extractor.parameters()) + list(lstm_model.parameters()), lr=0.0001)

# 使用 GPU 設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_feature_extractor.to(device)
lstm_model.to(device)

# 視頻處理
cap = cv2.VideoCapture("zebrafish_video.mp4")

解釋:

  • 首先初始化 YOLOv8 模型來檢測斑馬魚,SORT 跟踪器用於跟踪斑馬魚的軌跡。
  • 初始化 CNN 特徵提取器和 LSTM 模型,並將模型加載到 GPU 上(如果可用)。
  • criterion 是交叉熵損失函數,optimizer 是Adam優化器,用於訓練模型。
  • cap 用於讀取視頻文件。

4. 特徵提取與增強

def extract_cnn_features(bbox, frame):
    x1, y1, x2, y2 = bbox
    roi = frame[y1:y2, x1:x2]
    roi_resized = cv2.resize(roi, (64, 64))
    
    # 數據增強: 隨機翻轉、旋轉
    if np.random.rand() > 0.5:
        roi_resized = cv2.flip(roi_resized, 1)
    angle = np.random.uniform(-10, 10)
    M = cv2.getRotationMatrix2D((32, 32), angle, 1.0)
    roi_resized = cv2.warpAffine(roi_resized, M, (64, 64))
    
    roi_tensor = torch.tensor(roi_resized, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    
    with torch.no_grad():
        cnn_features = cnn_feature_extractor(roi_tensor)
    
    return cnn_features

解釋:

  • 這個函數負責從圖像幀中提取斑馬魚的ROI,並將其轉換為CNN的輸入。
  • roi_resized 是將ROI調整到64x64像素的大小。
  • 隨機翻轉和旋轉是數據增強技術,這些操作增加了模型的魯棒性。
  • 最後,將ROI轉換為Tensor格式,並輸入到CNN中提取深度特徵。

5. 可視化結果與行為分析主循環

def visualize_results(frame, bbox, behavior_class):
    x1, y1, x2, y2 = bbox
    color = (0, 255, 0)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    
    label = f"Behavior: {behavior_class}"
    cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # 斑馬魚檢測
    detections = yolo_model.detect(frame)
    
    # 斑馬魚跟蹤
    tracks = sort_tracker.update(detections)
    
    sequence_features = []
    
    for track in tracks:
        fish_id, bbox, _ = track
        cnn_features = extract_cnn_features(bbox, frame)
        
        sequence_features.append(cnn_features.squeeze(0).cpu().numpy())
    
    if len(sequence_features) > 0:
        sequence_features_tensor = torch.tensor(sequence_features, dtype=torch.float32).unsqueeze(0).to(device)
        
        # LSTM 模型預測
        with torch.no_grad():
            behavior_output = lstm_model(sequence_features_tensor)
        
        # 獲取行為分類結果
        _, behavior_class = torch.max(behavior_output, 1)
        behavior_class = behavior_class.item()
        
        # 可視化結果
        for i, track in enumerate(tracks):
            visualize_results(frame, track[1], behavior_class)
    
    cv2.imshow("Advanced Zebrafish Behavior Analysis", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

解釋:

  • visualize_results 用於將行為分類結果疊加到視頻幀上,在畫面上顯示斑馬

魚的行為類別。

  • while cap.isOpened() 是主循環,負責處理視頻中的每一幀。
  • 使用 YOLOv8 檢測斑馬魚,然後用 SORT 跟踪器跟踪檢測到的斑馬魚。
  • 對每個跟踪到的斑馬魚,提取其CNN特徵,並構建一個特徵序列。
  • LSTM 模型基於這個特徵序列來預測斑馬魚的行為,並將結果可視化。
  • 最後,顯示帶有行為分析結果的視頻,如果按下q鍵則退出。

這個程式碼結合了物體檢測、跟踪、特徵提取、行為分類和結果可視化,形成了一個完整的斑馬魚行為分析系統。模型的複雜性體現在自定義的CNN、使用多層LSTM捕捉時序依賴、以及數據增強和GPU加速上。


上一篇
day 28 基於人工智慧與深度學習斑馬魚行為分析
下一篇
day 30 基於人工智慧與深度學習對斑馬魚做行為分析
系列文
基於人工智慧與深度學習對斑馬魚做行為分析30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言